# --------------------------------------------------------
# modified from Hora

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import yaml
import argparse

import sys
sys.path.append('../IsaacGymEnvs2/isaacgymenvs')

from ddim.models.diffusion_controlseq import  ModelInvDyn
# from isaacgymenvs.ddim.main import dict2namespace

try:
    from isaacgym.torch_utils import to_torch, unscale
except:
    pass

def dict2namespace(config):
    """Recursively turn a (possibly nested) mapping into ``argparse.Namespace``.

    Nested ``dict`` values become nested namespaces; every other value is
    stored unchanged under its key.
    """
    converted = {
        key: (dict2namespace(value) if isinstance(value, dict) else value)
        for key, value in config.items()
    }
    return argparse.Namespace(**converted)



class MLP(nn.Module):
    """Feed-forward stack of ``Linear -> ELU`` pairs.

    Args:
        units: output width of each hidden layer, in order.
        input_size: width of the input features.

    The layers live in ``self.mlp`` (an ``nn.Sequential``), so parameter keys
    are ``mlp.0.*``, ``mlp.2.*``, ... with ELUs at the odd indices.
    """

    def __init__(self, units, input_size):
        super().__init__()
        widths = [input_size, *units]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ELU()]
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        return self.mlp(x)


class MLPMaskedNet(nn.Module):
    """Per-joint compensator: an independent MLP trunk + linear head per target joint.

    For every joint index in ``target_joint_idx_tensor`` an ``MLP(units, 3)``
    trunk and an ``nn.Linear(units[-1], actions_num)`` head are created.

    ``forward`` treats the last 16 features of ``x`` as policy actions and the
    rest as ``wm_history_length`` frame(s) laid out as ``[16 qpos | qtars...]``.
    It returns ``(features, mus)``: the per-joint trunk outputs concatenated
    along the last dim, and the per-joint head outputs concatenated likewise.
    """

    def __init__(self, units, target_joint_idx_tensor, actions_num):
        super(MLPMaskedNet, self).__init__()
        self.target_joint_idx_tensor = target_joint_idx_tensor
        self.target_joint_idx_list = target_joint_idx_tensor.detach().cpu().tolist()
        self.actions_num = actions_num
        self.wm_history_length = 1
        # one qpos value + one qtar value (single history frame) + the joint's action
        self.per_joint_input_dim = 3
        self.joint_idx_to_mlp = {}  # retained for attribute compatibility; never populated

        self.joint_mlp_modulelist = nn.ModuleList()
        self.mu_mlp_modulelist = nn.ModuleList()
        # Trunk and head are created back-to-back per joint (preserves parameter creation order).
        for _ in self.target_joint_idx_list:
            self.joint_mlp_modulelist.append(MLP(units, self.per_joint_input_dim))
            self.mu_mlp_modulelist.append(torch.nn.Linear(units[-1], actions_num))

    def forward(self, x):
        history = x[..., :-16]
        policy_actions = x[..., -16:]
        batch = history.shape[0]
        history = history.reshape(batch, self.wm_history_length, -1)
        qpos_hist = history[..., :16]
        qtar_hist = history[..., 16:]

        features, heads = [], []
        joints = zip(self.joint_mlp_modulelist, self.mu_mlp_modulelist, self.target_joint_idx_list)
        for trunk, head, joint in joints:
            cols = slice(joint, joint + 1)
            joint_input = torch.cat(
                [
                    qpos_hist[..., cols].reshape(qpos_hist.shape[0], -1),
                    qtar_hist[..., cols].reshape(qtar_hist.shape[0], -1),
                    policy_actions[..., cols],
                ],
                dim=-1,
            )
            feat = trunk(joint_input)  # bsz x units[-1]
            features.append(feat)
            heads.append(head(feat))

        return torch.cat(features, dim=-1), torch.cat(heads, dim=-1)


class ProprioAdaptTConv(nn.Module):
    """Temporal-conv encoder mapping a proprioceptive history to an 8-dim vector.

    Input:  ``x`` of shape (N, T, 32) — 32 = 16 + 16 features per step.
    Output: (N, 8).

    The three Conv1d layers (kernel 9/stride 2, kernel 5, kernel 5, no
    padding) shrink the time axis T -> floor((T - 9) / 2) + 1 -> -4 -> -4, and
    ``low_dim_proj`` expects exactly 3 remaining steps (32 * 3 inputs). Only
    T = 29 or T = 30 satisfies that (e.g. 30 -> 11 -> 7 -> 3); other history
    lengths fail inside ``low_dim_proj``.
    """

    def __init__(self):
        super(ProprioAdaptTConv, self).__init__()
        # per-step feature mixing: (N, T, 32) -> (N, T, 32)
        self.channel_transform = nn.Sequential(
            nn.Linear(16 + 16, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True),
        )
        # temporal downsampling over the T axis (channels = 32)
        self.temporal_aggregation = nn.Sequential(
            nn.Conv1d(32, 32, (9,), stride=(2,)),
            nn.ReLU(inplace=True),
            nn.Conv1d(32, 32, (5,), stride=(1,)),
            nn.ReLU(inplace=True),
            nn.Conv1d(32, 32, (5,), stride=(1,)),
            nn.ReLU(inplace=True),
        )
        self.low_dim_proj = nn.Linear(32 * 3, 8)

    def forward(self, x):
        x = self.channel_transform(x)  # (N, T, 32), e.g. T = 30
        x = x.permute((0, 2, 1))  # (N, 32, T)
        x = self.temporal_aggregation(x)  # (N, 32, 3) for T in {29, 30}
        x = self.low_dim_proj(x.flatten(1))  # (N, 8)
        return x


class ActionCompensator(nn.Module):
    """Three-layer ReLU MLP predicting a 16-dim delta action.

    ``forward`` reads ``input_dict['delta_action_input']``, whose last dim must
    be ``(16 + 16) * 10 = 320``, and returns the 16-dim network output.
    """

    def __init__(self, ):
        super().__init__()
        steps = 10
        n_obs = 16
        n_act = 16
        widths = [(n_obs + n_act) * steps, 256, 128, n_act]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth:
                layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Linear(fan_in, fan_out))
        self.delta_action_model = nn.Sequential(*layers)

    def forward(self, input_dict):
        return self.delta_action_model(input_dict['delta_action_input'])


class ActorCritic(nn.Module):
    """Gaussian-policy actor-critic with an optional privileged-info encoder.

    ``kwargs`` is a config dict. ``actions_num``, ``input_shape``,
    ``actor_units`` and ``priv_mlp_units`` are *popped* from it (the caller's
    dict is mutated). When ``kwargs['priv_info']`` is true an extrinsics vector
    (``env_mlp(priv_info)``, or ``adapt_tconv(proprio_hist)`` when
    ``kwargs['proprio_adapt']`` is set) is tanh-squashed and appended to the
    observation before the actor MLP.

    A per-joint action-compensator mode is available; it is off by default and
    is enabled by setting ``train_action_compensator`` and
    ``per_joint_action_compensator`` (optionally ``train_action_compensator_uan``)
    on the instance after construction.
    """

    def __init__(self, kwargs):
        nn.Module.__init__(self)
        actions_num = kwargs.pop('actions_num')
        input_shape = kwargs.pop('input_shape')
        self.units = kwargs.pop('actor_units')
        self.priv_mlp = kwargs.pop('priv_mlp_units')
        mlp_input_shape = input_shape[0]

        self.rollout_w_gt_extrin = kwargs.get('rollout_w_gt_extrin', False)
        self.detach_extrin = kwargs.get('detach_extrin', False)

        # Per-joint compensator switches (toggled externally).
        self.train_action_compensator = False
        self.per_joint_action_compensator = False
        # _actor_critic reads this whenever the compensator path is active, so it
        # must exist even if a caller only flips the two flags above.
        self.train_action_compensator_uan = False

        out_size = self.units[-1]
        self.priv_info = kwargs['priv_info']
        self.priv_info_stage2 = kwargs['proprio_adapt']
        if self.priv_info:
            mlp_input_shape += self.priv_mlp[-1]
            self.env_mlp = MLP(units=self.priv_mlp, input_size=kwargs['priv_info_dim'])

            if self.priv_info_stage2:
                self.adapt_tconv = ProprioAdaptTConv()

        # actor mlp + heads
        self.actor_mlp = MLP(units=self.units, input_size=mlp_input_shape)
        self.value = torch.nn.Linear(out_size, 1)
        self.mu = torch.nn.Linear(out_size, actions_num)
        # state-independent log-std, one entry per action dim
        self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                fan_out = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out))
                if getattr(m, 'bias', None) is not None:
                    torch.nn.init.zeros_(m.bias)
            if isinstance(m, nn.Linear):
                if getattr(m, 'bias', None) is not None:
                    torch.nn.init.zeros_(m.bias)
        nn.init.constant_(self.sigma, 0)

    @torch.no_grad()
    def act(self, obs_dict):
        """Sample an exploratory action (used while collecting rollouts)."""
        mu, logstd, value, _, _ = self._actor_critic(obs_dict)
        sigma = torch.exp(logstd)
        distr = torch.distributions.Normal(mu, sigma)
        selected_action = distr.sample()
        return {
            'neglogpacs': -distr.log_prob(selected_action).sum(1),
            'values': value,
            'actions': selected_action,
            'mus': mu,
            'sigmas': sigma,
        }

    @torch.no_grad()
    def act_inference(self, obs_dict):
        """Return the deterministic action ``mu``; the value estimate is kept in ``self.value_vals``."""
        mu, logstd, value, _, _ = self._actor_critic(obs_dict)
        self.value_vals = value
        return mu

    def _append_extrin(self, obs_dict, obs):
        """Encode extrinsics (if enabled) and append them to ``obs``.

        Returns ``(obs, extrin, extrin_gt)``. Both extrinsics are ``None`` when
        ``priv_info`` is off; ``extrin_gt`` is only produced in stage 2.
        """
        extrin, extrin_gt = None, None
        if not self.priv_info:
            return obs, extrin, extrin_gt

        if self.priv_info_stage2:
            extrin = self.adapt_tconv(obs_dict['proprio_hist'])
            # during supervised training, extrin has a gt label
            extrin_gt = self.env_mlp(obs_dict['priv_info']) if 'priv_info' in obs_dict else extrin
            extrin_gt = torch.tanh(extrin_gt)
            extrin = torch.tanh(extrin)
            if self.rollout_w_gt_extrin:
                obs = torch.cat([obs, extrin_gt], dim=-1)
            elif self.detach_extrin:
                obs = torch.cat([obs, extrin.detach()], dim=-1)
            else:
                obs = torch.cat([obs, extrin], dim=-1)
        else:
            extrin = torch.tanh(self.env_mlp(obs_dict['priv_info']))
            obs = torch.cat([obs, extrin], dim=-1)
        return obs, extrin, extrin_gt

    def _per_joint_compensator(self, obs, hist_len, nn_dof=16):
        """Apply the shared actor MLP independently to every joint.

        ``obs`` starts with ``hist_len`` frames of ``nn_dof`` values followed by
        ``nn_dof`` policy actions. Each joint's input row is its own history
        column plus its action (``hist_len + 1`` values). Returns ``mu`` of shape
        (nn_envs, nn_dof) — requires a 1-output ``mu`` head — and ``value``
        averaged over joints.
        """
        nn_envs = obs.size(0)
        hist_dim = nn_dof * hist_len
        obs_hist = obs[..., :hist_dim]
        obs_action = obs[..., hist_dim: hist_dim + nn_dof]
        # (nn_envs, hist_len, dof) -> (nn_envs, dof, hist_len) -> (nn_envs * dof, hist_len)
        per_dof_hist = obs_hist.reshape(nn_envs, hist_len, -1).permute(0, 2, 1)
        n_dof = per_dof_hist.size(1)
        per_dof_hist = per_dof_hist.reshape(nn_envs * n_dof, -1)
        per_dof_action = obs_action.reshape(obs_action.size(0) * obs_action.size(1), 1)
        x = self.actor_mlp(torch.cat([per_dof_hist, per_dof_action], dim=-1))
        mu = self.mu(x).reshape(nn_envs, n_dof)
        value = self.value(x).reshape(nn_envs, n_dof, -1).mean(dim=1)
        return mu, value

    def _actor_critic(self, obs_dict):
        """Return ``(mu, logstd, value, extrin, extrin_gt)``; ``logstd`` is broadcast to ``mu``'s shape."""
        obs, extrin, extrin_gt = self._append_extrin(obs_dict, obs_dict['obs'])
        if extrin is not None:
            self.extrin = extrin.detach().clone()

        if self.train_action_compensator and self.per_joint_action_compensator:
            hist_len = 20 if self.train_action_compensator_uan else 1
            mu, value = self._per_joint_compensator(obs, hist_len)
        else:
            x = self.actor_mlp(obs)
            value = self.value(x)
            mu = self.mu(x)
        sigma = self.sigma
        return mu, mu * 0 + sigma, value, extrin, extrin_gt

    def forward(self, input_dict):
        """Evaluate ``input_dict['prev_actions']`` under the current policy (PPO update step)."""
        prev_actions = input_dict.get('prev_actions', None)
        mu, logstd, value, extrin, extrin_gt = self._actor_critic(input_dict)
        sigma = torch.exp(logstd)
        distr = torch.distributions.Normal(mu, sigma)
        entropy = distr.entropy().sum(dim=-1)
        prev_neglogp = -distr.log_prob(prev_actions).sum(1)
        return {
            'prev_neglogp': torch.squeeze(prev_neglogp),
            'values': value,
            'entropy': entropy,
            'mus': mu,
            'sigmas': sigma,
            'extrin': extrin,
            'extrin_gt': extrin_gt,
        }




class ActorCriticMaskCompensator(nn.Module):
    """Actor-critic whose actor is an ``MLPMaskedNet`` (one small MLP per target joint).

    ``kwargs`` is a config dict; ``actions_num``, ``input_shape``,
    ``actor_units`` and ``priv_mlp_units`` are *popped* from it.
    ``kwargs['target_joint_idx_tensor']`` (required) lists the joints to
    compensate; ``actions_num`` must be a multiple of their count because each
    joint's head emits ``actions_num // num_joints`` values.

    Raises:
        ValueError: if ``target_joint_idx_tensor`` is missing/empty or
            ``actions_num`` is not divisible by the number of target joints.
    """

    def __init__(self, kwargs):
        nn.Module.__init__(self)
        actions_num = kwargs.pop('actions_num')
        kwargs.pop('input_shape')  # consumed for config parity; the masked net sizes itself per joint
        self.units = kwargs.pop('actor_units')
        self.priv_mlp = kwargs.pop('priv_mlp_units')

        self.rollout_w_gt_extrin = kwargs.get('rollout_w_gt_extrin', False)
        self.detach_extrin = kwargs.get('detach_extrin', False)

        self.priv_info = kwargs['priv_info']
        self.priv_info_stage2 = kwargs['proprio_adapt']
        if self.priv_info:
            self.env_mlp = MLP(units=self.priv_mlp, input_size=kwargs['priv_info_dim'])

            if self.priv_info_stage2:
                self.adapt_tconv = ProprioAdaptTConv()

        self.target_joint_idx_tensor = kwargs.get('target_joint_idx_tensor', None)
        if self.target_joint_idx_tensor is None:
            raise ValueError("ActorCriticMaskCompensator requires kwargs['target_joint_idx_tensor']")
        num_target_joints = self.target_joint_idx_tensor.size(0)
        if num_target_joints == 0 or actions_num % num_target_joints != 0:
            raise ValueError(
                f"actions_num ({actions_num}) must be a positive multiple of the number of "
                f"target joints ({num_target_joints})"
            )

        # NOTE(review): MLPMaskedNet treats the last 16 input features as policy
        # actions; when priv_info is on, extrinsics are appended after them in
        # _actor_critic and would be read as actions — confirm priv_info is off here.
        self.actor_mlp = MLPMaskedNet(
            units=self.units,
            target_joint_idx_tensor=self.target_joint_idx_tensor,
            actions_num=actions_num // num_target_joints,
        )
        full_out_size = self.units[-1] * num_target_joints

        # value head over the concatenated per-joint features
        self.value = torch.nn.Linear(full_out_size, 1)
        # Not used by _actor_critic (mu comes from the per-joint heads); kept so
        # existing checkpoints with 'mu.*' keys still load with strict=True.
        self.mu = torch.nn.Linear(full_out_size, actions_num)

        self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                fan_out = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out))
                if getattr(m, 'bias', None) is not None:
                    torch.nn.init.zeros_(m.bias)
            if isinstance(m, nn.Linear):
                if getattr(m, 'bias', None) is not None:
                    torch.nn.init.zeros_(m.bias)
        nn.init.constant_(self.sigma, 0)

    @torch.no_grad()
    def act(self, obs_dict):
        """Sample an exploratory action (used while collecting rollouts)."""
        mu, logstd, value, _, _ = self._actor_critic(obs_dict)
        sigma = torch.exp(logstd)
        distr = torch.distributions.Normal(mu, sigma)
        selected_action = distr.sample()
        return {
            'neglogpacs': -distr.log_prob(selected_action).sum(1),
            'values': value,
            'actions': selected_action,
            'mus': mu,
            'sigmas': sigma,
        }

    @torch.no_grad()
    def act_inference(self, obs_dict):
        """Return the deterministic action ``mu``; the value estimate is kept in ``self.value_vals``."""
        mu, logstd, value, _, _ = self._actor_critic(obs_dict)
        self.value_vals = value
        return mu

    def _append_extrin(self, obs_dict, obs):
        """Encode extrinsics (if enabled) and append them to ``obs``; returns ``(obs, extrin, extrin_gt)``."""
        extrin, extrin_gt = None, None
        if not self.priv_info:
            return obs, extrin, extrin_gt

        if self.priv_info_stage2:
            extrin = self.adapt_tconv(obs_dict['proprio_hist'])
            # during supervised training, extrin has a gt label
            extrin_gt = self.env_mlp(obs_dict['priv_info']) if 'priv_info' in obs_dict else extrin
            extrin_gt = torch.tanh(extrin_gt)
            extrin = torch.tanh(extrin)
            if self.rollout_w_gt_extrin:
                obs = torch.cat([obs, extrin_gt], dim=-1)
            elif self.detach_extrin:
                obs = torch.cat([obs, extrin.detach()], dim=-1)
            else:
                obs = torch.cat([obs, extrin], dim=-1)
        else:
            extrin = torch.tanh(self.env_mlp(obs_dict['priv_info']))
            obs = torch.cat([obs, extrin], dim=-1)
        return obs, extrin, extrin_gt

    def _actor_critic(self, obs_dict):
        """Return ``(mu, logstd, value, extrin, extrin_gt)``; ``logstd`` is broadcast to ``mu``'s shape."""
        obs, extrin, extrin_gt = self._append_extrin(obs_dict, obs_dict['obs'])
        if extrin is not None:
            self.extrin = extrin.detach().clone()

        x, mu = self.actor_mlp(obs)
        value = self.value(x)
        sigma = self.sigma
        return mu, mu * 0 + sigma, value, extrin, extrin_gt

    def forward(self, input_dict):
        """Evaluate ``input_dict['prev_actions']`` under the current policy (PPO update step)."""
        prev_actions = input_dict.get('prev_actions', None)
        mu, logstd, value, extrin, extrin_gt = self._actor_critic(input_dict)
        sigma = torch.exp(logstd)
        distr = torch.distributions.Normal(mu, sigma)
        entropy = distr.entropy().sum(dim=-1)
        prev_neglogp = -distr.log_prob(prev_actions).sum(1)
        return {
            'prev_neglogp': torch.squeeze(prev_neglogp),
            'values': value,
            'entropy': entropy,
            'mus': mu,
            'sigmas': sigma,
            'extrin': extrin,
            'extrin_gt': extrin_gt,
        }


from hora.algo.models.running_mean_std import RunningMeanStd


# Asymmetric behavior-cloning (BC) actor-critic #
class ActorCriticAsymmetricBC(nn.Module):
    """Asymmetric actor-critic whose actor is a behavior-cloned ``ModelInvDyn``.

    * Actor: ``self.actor_mlp`` (``ModelInvDyn``) is queried with ``actor_obs``
      and an all-zero future reference; its first 16 outputs are either mapped
      from the Allegro joint-limit range to [-1, 1], or — when
      ``tune_bc_via_compensator_model`` is set — fed (detached, together with
      ``actor_obs``) to a compensator MLP whose head produces ``mu``.
    * Critic: ``self.value(self.value_actor_mlp(ori_obs ++ extrin))`` where
      ``extrin = tanh(env_mlp(priv_info))``.

    ``kwargs`` is a config dict; ``actions_num``, ``input_shape``,
    ``actor_units`` and ``priv_mlp_units`` are *popped* from it. Optional keys:
    ``actor_model_input_dim`` (width of the actor slice of ``obs``),
    ``sigma_version`` (1, 2 or 3; default 2), ``tune_bc_via_compensator_model``.
    The inverse-dynamics config is read from a fixed relative path and the
    models are placed on CUDA.
    """

    def __init__(self, kwargs):
        nn.Module.__init__(self)
        actions_num = kwargs.pop('actions_num')
        kwargs.pop('input_shape')  # consumed for config parity; widths below are fixed
        self.units = kwargs.pop('actor_units')
        self.priv_mlp = kwargs.pop('priv_mlp_units')

        # The critic MLP consumes the 357-dim 'ori_obs' (+ extrinsics), not input_shape.
        mlp_input_shape = 357
        self.original_obs_dim = 357

        self.rollout_w_gt_extrin = kwargs.get('rollout_w_gt_extrin', False)
        self.detach_extrin = kwargs.get('detach_extrin', False)

        #### Priv info setting ####
        self.priv_info = kwargs['priv_info']
        self.priv_info_stage2 = kwargs['proprio_adapt']
        if self.priv_info:
            mlp_input_shape += self.priv_mlp[-1]
            self.env_mlp = MLP(units=self.priv_mlp, input_size=kwargs['priv_info_dim'])

            if self.priv_info_stage2:
                self.adapt_tconv = ProprioAdaptTConv()
        # NOTE(review): the flag is cleared here, yet _actor_critic always calls
        # env_mlp, so this model effectively requires kwargs['priv_info'] to be True.
        self.priv_info = False
        #### Priv info setting ####

        #### Inverse-dynamics (BC) actor / critic ####
        self.invdyn_v2_config_path = 'controlseq.yml'
        configs_folder = "../IsaacGymEnvs2/isaacgymenvs/ddim/configs"
        with open(os.path.join(configs_folder, self.invdyn_v2_config_path), "r") as f:
            config = yaml.safe_load(f)
        invdyn_config = dict2namespace(config)
        invdyn_config.device = 'cuda'
        invdyn_config.invdyn.model_arch = 'resmlp'
        invdyn_config.invdyn.res_blocks = 5
        invdyn_config.invdyn.pred_extrin = False
        invdyn_config.invdyn.history_length = 10
        invdyn_config.invdyn.future_length = 2
        invdyn_config.invdyn.future_ref_dim = 3
        invdyn_config.invdyn.train_value_network = False
        self.actor_mlp = ModelInvDyn(invdyn_config)

        # NOTE(review): the same config object is mutated and reused below; if
        # ModelInvDyn keeps a reference to it, actor_mlp also sees
        # train_value_network=True afterwards — confirm that is intended.
        invdyn_config.invdyn.train_value_network = True
        self.value_net = ModelInvDyn(invdyn_config)

        # RL critic on ori_obs (+ extrinsics)
        self.value_actor_mlp = MLP(units=self.units, input_size=mlp_input_shape)
        self.value = torch.nn.Linear(self.units[-1], 1)

        # Width of the actor slice at the front of input_dict['obs'] (see forward).
        # Default matches the compensator layout below, where actor_obs is
        # (16 + 16) * history_length wide; override via kwargs if different.
        self.actor_model_input_dim = kwargs.get(
            'actor_model_input_dim', (16 + 16) * invdyn_config.invdyn.history_length
        )

        self.tune_bc_via_compensator_model = kwargs.get('tune_bc_via_compensator_model', False)
        print(f"tune_bc_via_compensator_model: {self.tune_bc_via_compensator_model}")
        if self.tune_bc_via_compensator_model:
            # input = actor_obs ++ (detached) 16-dim BC action
            bc_compensator_model_input_dim = 16 + (16 + 16) * invdyn_config.invdyn.history_length
            self.bc_compensator_model = MLP(units=self.units, input_size=bc_compensator_model_input_dim)
            self.bc_compensator_running_mean_std = RunningMeanStd((bc_compensator_model_input_dim, ))
            self.bc_compensator_out_mus = torch.nn.Linear(self.units[-1], actions_num)

        # Allegro joint limits (16 DoF); float32 so unscale / sigma math stays float32.
        self.allegro_dof_lower = torch.tensor([
            -0.3140, -1.0470, -0.5060, -0.3660, -0.3490, -0.4700, -1.2000, -1.3400,
            -0.3140, -1.0470, -0.5060, -0.3660, -0.3140, -1.0470, -0.5060, -0.3660,
        ], dtype=torch.float32).cuda()
        self.allegro_dof_upper = torch.tensor([
            2.2300, 1.0470, 1.8850, 2.0420, 2.0940, 2.4430, 1.9000, 1.8800, 2.2300,
            1.0470, 1.8850, 2.0420, 2.2300, 1.0470, 1.8850, 2.0420,
        ], dtype=torch.float32).cuda()

        # How exp(logstd) is scaled into the policy std (see _logstd_to_sigma).
        self.sigma_version = kwargs.get('sigma_version', 2)
        self.sigma_scale_factor = 0.00001

        self.sigma = nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), requires_grad=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                fan_out = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out))
                if getattr(m, 'bias', None) is not None:
                    torch.nn.init.zeros_(m.bias)
            if isinstance(m, nn.Linear):
                if getattr(m, 'bias', None) is not None:
                    torch.nn.init.zeros_(m.bias)
        nn.init.constant_(self.sigma, 0)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Remove a leading ``module.`` (DataParallel/DDP wrapper) from each key that has it."""
        prefix = 'module.'
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    @staticmethod
    def _select_prefix(state_dict, prefix):
        """Return the entries whose key starts with ``prefix``, with the prefix removed."""
        return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

    @staticmethod
    def _unscale(x, lower, upper):
        """Map ``x`` from [lower, upper] to [-1, 1] (same formula as isaacgym.torch_utils.unscale)."""
        return (2.0 * x - upper - lower) / (upper - lower)

    def load_value_actor_mlp_and_value(self, ckpt_fn):
        """Load the critic MLP and value head from an ``ActorCritic`` checkpoint's ``'model'`` dict."""
        states = torch.load(ckpt_fn)['model']
        # Exact prefix match: substring matching would also catch keys such as
        # 'value_actor_mlp.*' or 'value_net.*' and break strict loading.
        self.value_actor_mlp.load_state_dict(self._select_prefix(states, 'actor_mlp.'), strict=True)
        self.value.load_state_dict(self._select_prefix(states, 'value.'), strict=True)

    def load_actor_models(self, ckpt_fn):
        """Load the BC actor from ``states[0]`` of the checkpoint (``states[-1]`` is not used)."""
        states = torch.load(ckpt_fn)
        self.actor_mlp.load_state_dict(self._strip_module_prefix(states[0]), strict=True)

    def load_critic_models(self, ckpt_fn):
        """Load the BC value network from ``states[0]`` of the checkpoint (``states[-1]`` is not used)."""
        states = torch.load(ckpt_fn)
        self.value_net.load_state_dict(self._strip_module_prefix(states[0]), strict=True)

    def _logstd_to_sigma(self, logstd):
        """Convert log-std to the policy std according to ``sigma_version``.

        Raises:
            ValueError: for an unknown ``sigma_version`` (previously left ``sigma`` unbound).
        """
        sigma = torch.exp(logstd)
        if self.tune_bc_via_compensator_model:
            return sigma
        if self.sigma_version == 1:
            return sigma * 1. / 24 * (2.0 / (self.allegro_dof_upper - self.allegro_dof_lower))
        if self.sigma_version == 2:
            return sigma * 1. / 50  # 0.02 scale
        if self.sigma_version == 3:
            return sigma * self.sigma_scale_factor
        raise ValueError(f"unknown sigma_version: {self.sigma_version}")

    @torch.no_grad()
    def act(self, obs_dict):
        """Sample an exploratory action (used while collecting rollouts)."""
        mu, logstd, value, _, _ = self._actor_critic(obs_dict)
        sigma = self._logstd_to_sigma(logstd)
        distr = torch.distributions.Normal(mu, sigma)
        selected_action = distr.sample()
        return {
            'neglogpacs': -distr.log_prob(selected_action).sum(1),
            'values': value,
            'actions': selected_action,
            'mus': mu,
            'sigmas': sigma,
        }

    @torch.no_grad()
    def act_inference(self, obs_dict):
        """Return the deterministic action ``mu``; the value estimate is kept in ``self.value_vals``."""
        mu, logstd, value, _, _ = self._actor_critic(obs_dict)
        self.value_vals = value
        return mu

    def _actor_critic(self, obs_dict):
        """Return ``(mu, logstd, value, extrin, extrin_gt)``; ``extrin_gt`` is always ``None``.

        Reads ``priv_info`` and ``ori_obs``, and ``actor_obs`` — falling back to
        the first ``actor_model_input_dim`` features of ``obs`` when absent
        (e.g. when called via ``act`` rather than ``forward``).
        """
        actor_obs = obs_dict.get('actor_obs')
        if actor_obs is None:
            actor_obs = obs_dict['obs'][..., :self.actor_model_input_dim]
        # The BC model is queried with an all-zero future reference.
        actor_future_ref = torch.zeros(
            (actor_obs.size(0), 6), dtype=torch.float32, device=actor_obs.device
        )

        extrin_gt = None
        extrin = torch.tanh(self.env_mlp(obs_dict['priv_info']))
        ori_obs = torch.cat([obs_dict['ori_obs'], extrin], dim=-1)

        x = self.actor_mlp(actor_obs, actor_future_ref)[..., :16]  # nn_envs x 16

        if self.tune_bc_via_compensator_model:
            # TODO: weights of actor_mlp should not be included in the optimizer.
            self.bc_model_actions = x.clone()
            compensator_in = torch.cat([actor_obs, x.detach()], dim=-1)
            compensator_in = self.bc_compensator_running_mean_std(compensator_in)
            x = self.bc_compensator_model(compensator_in)
            value = self.value(self.value_actor_mlp(ori_obs))  # asymmetric critic
            mu = self.bc_compensator_out_mus(x)
        else:
            value = self.value(self.value_actor_mlp(ori_obs))  # RL critic
            mu = self._unscale(x, self.allegro_dof_lower, self.allegro_dof_upper)

        sigma = self.sigma
        return mu, mu * 0 + sigma, value, extrin, extrin_gt

    def forward(self, input_dict):
        """Evaluate ``input_dict['prev_actions']`` under the current policy.

        Splits ``input_dict['obs']`` at ``actor_model_input_dim`` and stores the
        parts back into ``input_dict`` as ``'actor_obs'`` / ``'critic_obs'``.
        """
        prev_actions = input_dict.get('prev_actions', None)

        obs = input_dict['obs']
        input_dict.update(
            {
                'actor_obs': obs[..., :self.actor_model_input_dim],
                'critic_obs': obs[..., self.actor_model_input_dim:],
            }
        )

        mu, logstd, value, extrin, extrin_gt = self._actor_critic(input_dict)
        sigma = self._logstd_to_sigma(logstd)

        distr = torch.distributions.Normal(mu, sigma)
        entropy = distr.entropy().sum(dim=-1)
        prev_neglogp = -distr.log_prob(prev_actions).sum(1)
        return {
            'prev_neglogp': torch.squeeze(prev_neglogp),
            'values': value,
            'entropy': entropy,
            'mus': mu,
            'sigmas': sigma,
            'extrin': extrin,
            'extrin_gt': extrin_gt,
        }


